package db

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/cheggaaa/pb/v3"
	"github.com/go-redis/redis/v8"
	"github.com/knqyf263/go-cpe/common"
	"github.com/knqyf263/go-cpe/naming"
	"github.com/spf13/viper"
	"github.com/vulsio/go-cve-dictionary/config"
	"github.com/vulsio/go-cve-dictionary/fetcher/jvn"
	"github.com/vulsio/go-cve-dictionary/fetcher/nvd"
	log "github.com/vulsio/go-cve-dictionary/log"
	"github.com/vulsio/go-cve-dictionary/models"
	"golang.org/x/xerrors"
)

/**
# Redis Data Structure

- Sets
  ┌─────────────────────────────────────────┬──────────────────────┬───────────────────────────────────────────────────┐
  │                  KEY                    │        MEMBER        │                      PURPOSE                      │
  └─────────────────────────────────────────┴──────────────────────┴───────────────────────────────────────────────────┘
  ┌─────────────────────────────────────────┬──────────────────────┬───────────────────────────────────────────────────┐
  │ CVE#CPE#${part}#${vendor}#${product}    │ CVEID                │ Get CVEIDs by CPE URI (part, vendor, product)     │
  └─────────────────────────────────────────┴──────────────────────┴───────────────────────────────────────────────────┘

- Hash
  ┌──────────────────┬───────────────┬─────────────┬──────────────────────────────────────────────────┐
  │     HASH         │      FIELD    │    VALUE    │             PURPOSE                              │
  └──────────────────┴───────────────┴─────────────┴──────────────────────────────────────────────────┘
  ┌──────────────────┬───────────────┬─────────────┬──────────────────────────────────────────────────┐
  │ CVE#CVE#${CVEID} │ NVD/${JVNID}  │ ${CVEJSON}  │ Get CVEJSON by CVEID                             │
  ├──────────────────┼───────────────┼─────────────┼──────────────────────────────────────────────────┤
  │ CVE#DEP          │ NVD/JVN       │    JSON     │ TO DELETE OUTDATED AND UNNEEDED FIELD AND MEMBER │
  ├──────────────────┼───────────────┼─────────────┼──────────────────────────────────────────────────┤
  │ CVE#FETCHMETA    │ Revision      │ string      │ Get Go-Cve-Dictionary Binary Revision            │
  ├──────────────────┼───────────────┼─────────────┼──────────────────────────────────────────────────┤
  │ CVE#FETCHMETA    │ SchemaVersion │ uint        │ Get Go-Cve-Dictionary Schema Version             │
  ├──────────────────┼───────────────┼─────────────┼──────────────────────────────────────────────────┤
  │ CVE#FETCHMETA    │ LastFetchedAt │ time.Time   │ Get Go-Cve-Dictionary Last Fetched Time          │
  └──────────────────┴───────────────┴─────────────┴──────────────────────────────────────────────────┘

**/

const (
	// dialectRedis is the database type name used in this driver's error messages.
	dialectRedis = "redis"
)

const (
	// cveKeyFormat builds the hash key holding a CVE's NVD/JVN JSON; the
	// argument is the CVE ID.
	cveKeyFormat = "CVE#CVE#%s"

	// cpeKeyFormat builds the set key listing CVE IDs for one CPE; the
	// argument is "part#vendor#product".
	cpeKeyFormat = "CVE#CPE#%s"

	// depKey is the hash recording, per source, what the last insert wrote so
	// that outdated fields and set members can be removed later.
	depKey = "CVE#DEP"

	// fetchMetaKey is the hash storing Revision, SchemaVersion and LastFetchedAt.
	fetchMetaKey = "CVE#FETCHMETA"
)

// RedisDriver is the Redis-backed database driver.
//
// Data is laid out in Redis using the CVE#... keys described in the comment
// near the top of this file.
type RedisDriver struct {
	// name is returned by Name and included in CloseDB's error message.
	name string
	// conn is the client created by connectRedis; CloseDB treats nil as
	// "not opened".
	conn *redis.Client
}

// Name reports the driver name held in the RedisDriver's name field.
func (r *RedisDriver) Name() string {
	return r.name
}

// OpenDB connects to the Redis server addressed by the dbPath URL.
//
// The first (string) and third (bool) parameters are unused by this driver;
// option is forwarded to connectRedis. Any failure is wrapped with the DB
// type and path.
func (r *RedisDriver) OpenDB(_, dbPath string, _ bool, option Option) error {
	err := r.connectRedis(dbPath, option)
	if err == nil {
		return nil
	}
	return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dialectRedis, dbPath, err)
}

// connectRedis parses dbPath as a Redis URL, creates the client and checks
// connectivity with PING.
//
// A positive option.RedisTimeout overrides the client's read timeout.
// Errors are returned with context instead of being logged here: the caller
// (OpenDB) already wraps and reports them, and logging too would report the
// same failure twice.
func (r *RedisDriver) connectRedis(dbPath string, option Option) error {
	opt, err := redis.ParseURL(dbPath)
	if err != nil {
		return xerrors.Errorf("Failed to parse url. err: %w", err)
	}
	if option.RedisTimeout > 0 {
		opt.ReadTimeout = option.RedisTimeout
	}
	r.conn = redis.NewClient(opt)
	if err := r.conn.Ping(context.Background()).Err(); err != nil {
		return xerrors.Errorf("Failed to Ping. err: %w", err)
	}
	return nil
}

// CloseDB closes the Redis client if one was opened.
//
// Calling it on a driver that was never opened is a no-op that returns nil.
func (r *RedisDriver) CloseDB() (err error) {
	if r.conn != nil {
		if closeErr := r.conn.Close(); closeErr != nil {
			return xerrors.Errorf("Failed to close DB. Type: %s. err: %w", r.name, closeErr)
		}
	}
	return nil
}

// MigrateDB does nothing for Redis and always reports success.
func (r *RedisDriver) MigrateDB() error {
	return nil
}

// IsGoCVEDictModelV1 determines if the DB was created at the time of go-cve-dictionary Model v1
//
// The DB is considered v1 when the CVE#FETCHMETA hash does not exist but at
// least one key matching "CVE#*" does. It returns false when fetch metadata
// exists, and when no "CVE#*" key exists at all.
func (r *RedisDriver) IsGoCVEDictModelV1() (bool, error) {
	// scanCount is the per-call COUNT hint passed to SCAN.
	const scanCount = 1000

	ctx := context.Background()

	exists, err := r.conn.Exists(ctx, fetchMetaKey).Result()
	if err != nil {
		return false, xerrors.Errorf("Failed to Exists. err: %w", err)
	}
	if exists != 0 {
		return false, nil
	}

	// One SCAN call only visits part of the keyspace and applies MATCH after
	// sampling, so it may return no keys (with a non-zero cursor) even though
	// matching keys exist. Keep iterating until a match is found or the cursor
	// comes back to 0, which marks a completed full iteration.
	var cursor uint64
	for {
		keys, next, err := r.conn.Scan(ctx, cursor, "CVE#*", scanCount).Result()
		if err != nil {
			return false, xerrors.Errorf("Failed to Scan. err: %w", err)
		}
		if len(keys) > 0 {
			return true, nil
		}
		if next == 0 {
			return false, nil
		}
		cursor = next
	}
}

// GetFetchMeta get FetchMeta from Database
//
// When the CVE#FETCHMETA hash is absent, it returns metadata for the running
// binary (config.Revision, models.LatestSchemaVersion) with LastFetchedAt set
// to the placeholder 1000-01-01T00:00:00Z. When the hash exists, Revision and
// SchemaVersion must both be present, while a missing LastFetchedAt falls back
// to the same placeholder.
func (r *RedisDriver) GetFetchMeta() (*models.FetchMeta, error) {
	ctx := context.Background()
	placeholder := time.Date(1000, time.January, 1, 0, 0, 0, 0, time.UTC)

	n, err := r.conn.Exists(ctx, fetchMetaKey).Result()
	if err != nil {
		return nil, xerrors.Errorf("Failed to Exists. err: %w", err)
	}
	if n == 0 {
		return &models.FetchMeta{
			GoCVEDictRevision: config.Revision,
			SchemaVersion:     models.LatestSchemaVersion,
			LastFetchedAt:     placeholder,
		}, nil
	}

	rev, err := r.conn.HGet(ctx, fetchMetaKey, "Revision").Result()
	if err != nil {
		return nil, xerrors.Errorf("Failed to HGet Revision. err: %w", err)
	}

	rawVersion, err := r.conn.HGet(ctx, fetchMetaKey, "SchemaVersion").Result()
	if err != nil {
		return nil, xerrors.Errorf("Failed to HGet SchemaVersion. err: %w", err)
	}
	ver, err := strconv.ParseUint(rawVersion, 10, 8)
	if err != nil {
		return nil, xerrors.Errorf("Failed to ParseUint. err: %w", err)
	}

	lastFetched := placeholder
	rawDate, err := r.conn.HGet(ctx, fetchMetaKey, "LastFetchedAt").Result()
	switch {
	case err == nil:
		if lastFetched, err = time.Parse(time.RFC3339, rawDate); err != nil {
			return nil, xerrors.Errorf("Failed to Parse date. err: %w", err)
		}
	case !errors.Is(err, redis.Nil):
		return nil, xerrors.Errorf("Failed to HGet LastFetchedAt. err: %w", err)
	}

	return &models.FetchMeta{GoCVEDictRevision: rev, SchemaVersion: uint(ver), LastFetchedAt: lastFetched}, nil
}

// UpsertFetchMeta upsert FetchMeta to Database
//
// Only fetchMeta.LastFetchedAt is taken from the argument. Revision and
// SchemaVersion are always written from config.Revision and
// models.LatestSchemaVersion.
func (r *RedisDriver) UpsertFetchMeta(fetchMeta *models.FetchMeta) error {
	ctx := context.Background()
	fields := map[string]interface{}{
		"Revision":      config.Revision,
		"SchemaVersion": models.LatestSchemaVersion,
		"LastFetchedAt": fetchMeta.LastFetchedAt,
	}
	return r.conn.HSet(ctx, fetchMetaKey, fields).Err()
}

// Get Select Cve information from DB.
//
// It reads every field of the CVE#CVE#<cveID> hash and decodes it with
// convertToCveDetail. If the hash has no fields, the returned CveDetail has
// empty Nvds and Jvns.
func (r *RedisDriver) Get(cveID string) (*models.CveDetail, error) {
	key := fmt.Sprintf(cveKeyFormat, cveID)
	fields, err := r.conn.HGetAll(context.Background(), key).Result()
	if err != nil {
		return nil, xerrors.Errorf("Failed to HGetAll. err: %w", err)
	}

	detail, err := convertToCveDetail(cveID, fields)
	if err != nil {
		return nil, xerrors.Errorf("Failed to convertToCveDetail. err: %w", err)
	}
	return &detail, nil
}

// convertToCveDetail decodes the fields of a CVE#CVE#<cveID> hash into a
// CveDetail.
//
// The models.NvdType field holds the NVD JSON. Fields prefixed with "JVNDB-"
// hold one JVN JSON document each. Any other field is logged and skipped.
// results is only read, never modified, so callers may keep using it.
func convertToCveDetail(cveID string, results map[string]string) (models.CveDetail, error) {
	detail := models.CveDetail{
		CveID: cveID,
		Nvds:  []models.Nvd{},
		Jvns:  []models.Jvn{},
	}

	if jsonStr, ok := results[models.NvdType]; ok {
		var nvd models.Nvd
		if err := json.Unmarshal([]byte(jsonStr), &nvd); err != nil {
			return models.CveDetail{}, xerrors.Errorf("Failed to Unmarshal JSON. err: %w", err)
		}
		detail.Nvds = append(detail.Nvds, nvd)
	}

	for field, jsonStr := range results {
		if field == models.NvdType {
			// Already decoded above; skipped here instead of deleting it from
			// the caller's map.
			continue
		}
		if !strings.HasPrefix(field, "JVNDB-") {
			log.Warnf("field(%s) obtained by %s is not in JVN format", field, cveID)
			continue
		}

		var jvn models.Jvn
		if err := json.Unmarshal([]byte(jsonStr), &jvn); err != nil {
			return models.CveDetail{}, xerrors.Errorf("Failed to Unmarshal JSON. err: %w", err)
		}
		detail.Jvns = append(detail.Jvns, jvn)
	}

	return detail, nil
}

// GetMulti Select Cves information from DB.
//
// All HGETALL commands are sent in one pipeline. The result is keyed by CVE
// ID. An ID whose hash has no fields maps to a CveDetail with empty Nvds and
// Jvns.
func (r *RedisDriver) GetMulti(cveIDs []string) (map[string]models.CveDetail, error) {
	ctx := context.Background()

	pipe := r.conn.Pipeline()
	pending := make(map[string]*redis.StringStringMapCmd, len(cveIDs))
	for _, id := range cveIDs {
		pending[id] = pipe.HGetAll(ctx, fmt.Sprintf(cveKeyFormat, id))
	}
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, xerrors.Errorf("Failed to exec pipeline. err: %w", err)
	}

	details := make(map[string]models.CveDetail, len(pending))
	for id, cmd := range pending {
		fields, err := cmd.Result()
		if err != nil {
			return nil, xerrors.Errorf("Failed to HGetAll. err: %w", err)
		}
		detail, err := convertToCveDetail(id, fields)
		if err != nil {
			return nil, xerrors.Errorf("Failed to convertToCveDetail. err: %w", err)
		}
		details[id] = detail
	}
	return details, nil
}

// GetCveIDsByCpeURI Select Cve Ids by pseudo-CPE
//
// It reads candidate CVE IDs from the CVE#CPE#<part>#<vendor>#<product> set
// for uri, loads their details, applies filterCveDetailByCpeURI, and then
// calls matchCpe. A CVE whose NVD data matches goes into the first slice.
// Otherwise, if its JVN data matches, it goes into the second slice. CVEs for
// which matchCpe returns an error are logged and skipped. Result order is not
// specified.
func (r *RedisDriver) GetCveIDsByCpeURI(uri string) ([]string, []string, error) {
	specified, err := naming.UnbindURI(uri)
	if err != nil {
		return nil, nil, xerrors.Errorf("Failed to UnbindURI. uri: %s, err: %w", uri, err)
	}
	cpeKey := fmt.Sprintf(cpeKeyFormat, fmt.Sprintf("%s#%s#%s", specified.Get(common.AttributePart), specified.Get(common.AttributeVendor), specified.Get(common.AttributeProduct)))

	cveIDs, err := r.conn.SMembers(context.Background(), cpeKey).Result()
	if err != nil {
		return nil, nil, xerrors.Errorf("Failed to SMembers. err: %w", err)
	}

	cveDetails, err := r.GetMulti(cveIDs)
	if err != nil {
		return nil, nil, xerrors.Errorf("Failed to GetMulti. err: %w", err)
	}

	nvdCveIDs := []string{}
	jvnCveIDs := []string{}
	for _, detail := range cveDetails {
		if err := filterCveDetailByCpeURI(uri, &detail); err != nil {
			return nil, nil, xerrors.Errorf("Failed to filterCveDetailByCpeURI. err: %w", err)
		}
		nvdMatch, jvnMatch, err := matchCpe(uri, &detail)
		if err != nil {
			log.Warnf("Failed to compare the version:%s %s %#v", err, uri, &detail)
			// continue matching
			continue
		}
		// NVD takes precedence: a CVE is reported once, under NVD if it matched there.
		if nvdMatch {
			nvdCveIDs = append(nvdCveIDs, detail.CveID)
		} else if jvnMatch {
			jvnCveIDs = append(jvnCveIDs, detail.CveID)
		}
	}

	return nvdCveIDs, jvnCveIDs, nil
}

// GetByCpeURI Select Cve information from DB.
//
// It loads every CVE listed in the CVE#CPE#<part>#<vendor>#<product> set for
// uri and passes each through filterCveDetailByCpeURI. It returns those that
// still have at least one NVD or JVN entry afterwards. Result order is not
// specified.
func (r *RedisDriver) GetByCpeURI(uri string) ([]models.CveDetail, error) {
	specified, err := naming.UnbindURI(uri)
	if err != nil {
		return nil, xerrors.Errorf("Failed to UnbindURI. uri: %s, err: %w", uri, err)
	}
	cpeKey := fmt.Sprintf(cpeKeyFormat, fmt.Sprintf("%s#%s#%s", specified.Get(common.AttributePart), specified.Get(common.AttributeVendor), specified.Get(common.AttributeProduct)))

	cveIDs, err := r.conn.SMembers(context.Background(), cpeKey).Result()
	if err != nil {
		return nil, xerrors.Errorf("Failed to SMembers. err: %w", err)
	}

	cveDetails, err := r.GetMulti(cveIDs)
	if err != nil {
		return nil, xerrors.Errorf("Failed to GetMulti. err: %w", err)
	}

	details := []models.CveDetail{}
	for _, detail := range cveDetails {
		if err := filterCveDetailByCpeURI(uri, &detail); err != nil {
			return nil, xerrors.Errorf("Failed to filterCveDetailByCpeURI. err: %w", err)
		}
		if len(detail.Nvds) > 0 || len(detail.Jvns) > 0 {
			details = append(details, detail)
		}
	}

	return details, nil
}

// CountJvn count jvn table
//
// It returns the number of JVN IDs recorded in the models.JvnType field of
// CVE#DEP, or 0 when that field does not exist (no JVN insert yet).
func (r *RedisDriver) CountJvn() (int, error) {
	depstr, err := r.conn.HGet(context.Background(), depKey, models.JvnType).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return 0, nil
		}
		return 0, xerrors.Errorf("Failed to HGet key: %s. err: %w", depKey, err)
	}

	// deps: {"JVNID": {"CVEID": {"part#vendor#product": {}}}}
	// Only the top-level JVN IDs are counted, so the nested per-CVE maps are
	// left as raw JSON instead of being decoded.
	var deps map[string]json.RawMessage
	if err := json.Unmarshal([]byte(depstr), &deps); err != nil {
		return 0, xerrors.Errorf("Failed to unmarshal JSON. err: %w", err)
	}

	return len(deps), nil
}

// InsertJvn insert items fetched from JVN.
//
// It fetches the JVN "recent" and "modified" feeds plus every year in years.
// When years is empty it uses 1998 through the current year. Each entry is
// stored as field <JVNID> of CVE#CVE#<CVEID>, and its CVE ID is added to the
// CVE#CPE#<part>#<vendor>#<product> set of each of its CPEs. Writes go out in
// pipelines of viper's "batch-size" entries.
//
// The dependency map from the previous run is stored in the models.JvnType
// field of CVE#DEP. It is used to remove what this run no longer produced:
//   - a CVE is SREM'd from a CPE set only when no JVN entry of this run still
//     links that CVE and CPE;
//   - field <JVNID> is HDEL'd from CVE#CVE#<CVEID> unless this run stored that
//     JVN ID for that CVE again.
//
// The map is then replaced with this run's dependencies.
//
// NOTE(review): the CVE#CPE# sets are shared with InsertNvd, so a member
// removed here may still be needed by NVD data for the same CVE and CPE.
// Handling that requires the NVD dependency map; left as-is.
func (r *RedisDriver) InsertJvn(years []string) error {
	ctx := context.Background()
	batchSize := viper.GetInt("batch-size")
	if batchSize < 1 {
		return fmt.Errorf("Failed to set batch-size. err: batch-size option is not set properly")
	}

	// {"year", "recent" or "modified": { "JVNID#CVE-ID":Jvn{} } }
	uniqCves := map[string]map[string]models.Jvn{}

	log.Infof("Fetching CVE information from JVN(recent, modified).")
	if err := jvn.FetchConvert(uniqCves, []string{"recent", "modified"}); err != nil {
		return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
	}

	// newDeps, oldDeps: {"JVNID": {"CVEID": {"part#vendor#product": {}}}}
	newDeps := map[string]map[string]map[string]struct{}{}
	// newCpesByCve: {"CVEID": {"part#vendor#product": {}}}, merged across all
	// JVN IDs inserted by this run. Several JVN entries may share one CVE ID,
	// and CPE sets are keyed by CVE ID only, so this index decides whether a
	// stale member is still needed by another JVN entry.
	newCpesByCve := map[string]map[string]struct{}{}

	oldDepsStr, err := r.conn.HGet(ctx, depKey, models.JvnType).Result()
	if err != nil {
		if !errors.Is(err, redis.Nil) {
			return xerrors.Errorf("Failed to Get key: %s. err: %w", depKey, err)
		}
		// No previous JVN insert: nothing to clean up.
		oldDepsStr = "{}"
	}
	var oldDeps map[string]map[string]map[string]struct{}
	if err := json.Unmarshal([]byte(oldDepsStr), &oldDeps); err != nil {
		return xerrors.Errorf("Failed to unmarshal JSON. err: %w", err)
	}

	if len(years) == 0 {
		for y := 1998; y <= time.Now().Year(); y++ {
			years = append(years, fmt.Sprintf("%d", y))
		}
	}
	for _, year := range years {
		log.Infof("Fetching CVE information from JVN(%s).", year)
		if err := jvn.FetchConvert(uniqCves, []string{year}); err != nil {
			return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
		}

		cves := []models.Jvn{}
		for _, cve := range uniqCves[year] {
			cves = append(cves, cve)
		}
		delete(uniqCves, year)

		log.Infof("Inserting fetched CVEs(%s)...", year)
		bar := pb.StartNew(len(cves))
		for idx := range chunkSlice(len(cves), batchSize) {
			pipe := r.conn.Pipeline()
			for _, cve := range cves[idx.From:idx.To] {
				jn, err := json.Marshal(cve)
				if err != nil {
					return xerrors.Errorf("Failed to marshal json. err: %w", err)
				}

				_ = pipe.HSet(ctx, fmt.Sprintf(cveKeyFormat, cve.CveID), cve.JvnID, string(jn))
				if _, ok := newDeps[cve.JvnID]; !ok {
					newDeps[cve.JvnID] = map[string]map[string]struct{}{}
				}
				// Recorded even without CPEs so the cleanup keeps this <JVNID> field.
				if _, ok := newDeps[cve.JvnID][cve.CveID]; !ok {
					newDeps[cve.JvnID][cve.CveID] = map[string]struct{}{}
				}

				for _, cpe := range cve.Cpes {
					cpePartVendorProductStr := fmt.Sprintf("%s#%s#%s", cpe.Part, cpe.Vendor, cpe.Product)
					_ = pipe.SAdd(ctx, fmt.Sprintf(cpeKeyFormat, cpePartVendorProductStr), cve.CveID)
					newDeps[cve.JvnID][cve.CveID][cpePartVendorProductStr] = struct{}{}
					if newCpesByCve[cve.CveID] == nil {
						newCpesByCve[cve.CveID] = map[string]struct{}{}
					}
					newCpesByCve[cve.CveID][cpePartVendorProductStr] = struct{}{}

					// Still produced by this JVN entry, so it is not stale.
					if oldCpes, ok := oldDeps[cve.JvnID][cve.CveID]; ok {
						delete(oldCpes, cpePartVendorProductStr)
						if len(oldCpes) == 0 {
							delete(oldDeps[cve.JvnID], cve.CveID)
						}
					}
				}
				if oldCves, ok := oldDeps[cve.JvnID]; ok && len(oldCves) == 0 {
					delete(oldDeps, cve.JvnID)
				}
			}
			if _, err := pipe.Exec(ctx); err != nil {
				return xerrors.Errorf("Failed to exec pipeline. err: %w", err)
			}
			bar.Add(idx.To - idx.From)
		}
		bar.Finish()
		log.Infof("Refreshed %d CVEs.", len(cves))
	}

	// What remains in oldDeps was written by a previous run but not by this one.
	pipe := r.conn.Pipeline()
	for jvnID, cves := range oldDeps {
		for cveID, cpes := range cves {
			for cpePartVendorProductStr := range cpes {
				if _, ok := newCpesByCve[cveID][cpePartVendorProductStr]; ok {
					// Another JVN entry of this run still links cveID to this CPE.
					continue
				}
				_ = pipe.SRem(ctx, fmt.Sprintf(cpeKeyFormat, cpePartVendorProductStr), cveID)
			}
			// Remove the <JVNID> field unless this run stored this JVN ID for
			// cveID again. A missing newDeps[jvnID] is a nil map, whose lookup
			// is simply !ok, so this also covers JVN IDs that disappeared.
			if _, ok := newDeps[jvnID][cveID]; !ok {
				_ = pipe.HDel(ctx, fmt.Sprintf(cveKeyFormat, cveID), jvnID)
			}
		}
	}
	newDepsJSON, err := json.Marshal(newDeps)
	if err != nil {
		return xerrors.Errorf("Failed to Marshal JSON. err: %w", err)
	}
	_ = pipe.HSet(ctx, depKey, models.JvnType, string(newDepsJSON))
	if _, err := pipe.Exec(ctx); err != nil {
		return xerrors.Errorf("Failed to exec pipeline. err: %w", err)
	}

	return nil
}

// CountNvd count nvd table
//
// It returns the number of CVE IDs recorded in the models.NvdType field of
// CVE#DEP, or 0 when that field does not exist (no NVD insert yet).
func (r *RedisDriver) CountNvd() (int, error) {
	depstr, err := r.conn.HGet(context.Background(), depKey, models.NvdType).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return 0, nil
		}
		return 0, xerrors.Errorf("Failed to HGet key: %s. err: %w", depKey, err)
	}

	// deps: {"CVEID": {"part#vendor#product": {}}}
	// Only the top-level CVE IDs are counted, so the per-CVE CPE sets are
	// left as raw JSON instead of being decoded.
	var deps map[string]json.RawMessage
	if err := json.Unmarshal([]byte(depstr), &deps); err != nil {
		return 0, xerrors.Errorf("Failed to unmarshal JSON. err: %w", err)
	}

	return len(deps), nil
}

// InsertNvd inserts CVE information fetched from NVD into the DB.
//
// It fetches the NVD "recent" and "modified" feeds plus every year in years.
// When years is empty it uses 2002 through the current year. For each CVE it
// sets field models.NvdType of CVE#CVE#<CVEID> to the CVE's JSON. It also adds
// the CVE ID to CVE#CPE#<part>#<vendor>#<product> for each of its CPEs. Writes
// go out in pipelines of viper's "batch-size" entries.
//
// The previous run's dependency map is stored in the models.NvdType field of
// CVE#DEP. It drives the cleanup:
//   - CPE set members the previous run added but this run did not are SREM'd;
//   - the NVD field is HDEL'd from CVEs this run did not insert.
//
// The map is then replaced with this run's dependencies.
//
// NOTE(review): the CVE#CPE# sets are shared with InsertJvn and keyed only by
// CVE ID, so the SREM cleanup may remove a member that JVN data for the same
// CVE and CPE still relies on — confirm this is acceptable.
func (r *RedisDriver) InsertNvd(years []string) error {
	ctx := context.Background()
	batchSize := viper.GetInt("batch-size")
	if batchSize < 1 {
		return fmt.Errorf("Failed to set batch-size. err: batch-size option is not set properly")
	}
	var err error

	// {"year", "recent" or "modified": { "CVE-ID": Nvd{} } }
	uniqCves := map[string]map[string]models.Nvd{}

	log.Infof("Fetching CVE information from NVD(recent, modified).")
	if err := nvd.FetchConvert(uniqCves, []string{"recent", "modified"}); err != nil {
		return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
	}

	// newDeps, oldDeps: {"CVEID": {"part#vendor#product": {}}}
	newDeps := map[string]map[string]struct{}{}
	oldDepsStr, err := r.conn.HGet(ctx, depKey, models.NvdType).Result()
	if err != nil {
		if !errors.Is(err, redis.Nil) {
			return xerrors.Errorf("Failed to Get key: %s. err: %w", depKey, err)
		}
		// No previous NVD insert: start from an empty dependency map.
		oldDepsStr = "{}"
	}
	var oldDeps map[string]map[string]struct{}
	if err := json.Unmarshal([]byte(oldDepsStr), &oldDeps); err != nil {
		return xerrors.Errorf("Failed to unmarshal JSON. err: %w", err)
	}

	if len(years) == 0 {
		for y := 2002; y <= time.Now().Year(); y++ {
			years = append(years, fmt.Sprintf("%d", y))
		}
	}
	for _, year := range years {
		log.Infof("Fetching CVE information from NVD(%s).", year)
		if err := nvd.FetchConvert(uniqCves, []string{year}); err != nil {
			return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
		}

		cves := []models.Nvd{}
		for _, cve := range uniqCves[year] {
			cves = append(cves, cve)
		}
		// This year's entries are not needed after being copied into cves.
		delete(uniqCves, year)

		log.Infof("Inserting fetched CVEs(%s)...", year)
		bar := pb.StartNew(len(cves))
		for idx := range chunkSlice(len(cves), batchSize) {
			pipe := r.conn.Pipeline()
			for _, cve := range cves[idx.From:idx.To] {
				var jn []byte
				if jn, err = json.Marshal(cve); err != nil {
					return xerrors.Errorf("Failed to marshal json. err: %w", err)
				}

				_ = pipe.HSet(ctx, fmt.Sprintf(cveKeyFormat, cve.CveID), models.NvdType, string(jn))
				// Recorded even without CPEs so the cleanup keeps this CVE's NVD field.
				if _, ok := newDeps[cve.CveID]; !ok {
					newDeps[cve.CveID] = map[string]struct{}{}
				}

				for _, cpe := range cve.Cpes {
					cpePartVendorProductStr := fmt.Sprintf("%s#%s#%s", cpe.Part, cpe.Vendor, cpe.Product)
					_ = pipe.SAdd(ctx, fmt.Sprintf(cpeKeyFormat, cpePartVendorProductStr), cve.CveID)
					newDeps[cve.CveID][cpePartVendorProductStr] = struct{}{}
					// Still produced for this CVE, so it must not be SREM'd in the cleanup.
					if _, ok := oldDeps[cve.CveID]; ok {
						delete(oldDeps[cve.CveID], cpePartVendorProductStr)
					}
				}
				if _, ok := oldDeps[cve.CveID]; ok {
					if len(oldDeps[cve.CveID]) == 0 {
						delete(oldDeps, cve.CveID)
					}
				}
			}
			if _, err = pipe.Exec(ctx); err != nil {
				return xerrors.Errorf("Failed to exec pipeline. err: %w", err)
			}
			bar.Add(idx.To - idx.From)
		}
		bar.Finish()
		log.Infof("Refreshed %d CVEs.", len(cves))
	}

	// What remains in oldDeps was written by a previous run but not by this one.
	pipe := r.conn.Pipeline()
	for cveID, cpes := range oldDeps {
		for cpePartVendorProductStr := range cpes {
			_ = pipe.SRem(ctx, fmt.Sprintf(cpeKeyFormat, cpePartVendorProductStr), cveID)
		}
		// Only drop the NVD field when this run did not insert the CVE at all.
		if _, ok := newDeps[cveID]; !ok {
			_ = pipe.HDel(ctx, fmt.Sprintf(cveKeyFormat, cveID), models.NvdType)
		}
	}
	newDepsJSON, err := json.Marshal(newDeps)
	if err != nil {
		return xerrors.Errorf("Failed to Marshal JSON. err: %w", err)
	}
	_ = pipe.HSet(ctx, depKey, models.NvdType, string(newDepsJSON))
	if _, err = pipe.Exec(ctx); err != nil {
		return xerrors.Errorf("Failed to exec pipeline. err: %w", err)
	}

	return nil
}
